import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet50
import time
import glob
import os
# Presumably a workaround for the Intel OpenMP "libiomp5 already initialized"
# abort that can occur when two OpenMP runtimes are loaded into one process.
# NOTE(review): Intel describes this override as unsafe/unsupported — confirm
# it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
import warnings
# Suppress every warning for the remainder of the process.
warnings.filterwarnings(action="ignore")
def main(*, image_dir="test", output_dir="./output",
         weights_path="./resNet50.pth", json_path="./class_indices.json",
         patterns=("*.jpg",)):
    """Classify images with a trained ResNet-50 and save an annotated plot per image.

    Every file in *image_dir* matching one of *patterns* is loaded, converted
    to RGB, preprocessed and passed through ``resnet50(num_classes=10)``.
    The top-1 label, its softmax probability and the elapsed time (image
    load + preprocessing + forward pass) are printed and used as the title
    of a matplotlib figure saved to ``<output_dir>/<image stem>.jpg``.

    All arguments are keyword-only and default to the values the script has
    always used, so a bare ``main()`` call behaves as before.

    Args:
        image_dir: Directory searched for input images.
        output_dir: Directory the annotated figures are written to; created
            if missing.
        weights_path: State-dict file loaded into the model.
        json_path: JSON object whose keys are class indices as strings
            (``"0"``, ``"1"``, ...) and whose values are the labels printed.
        patterns: Glob patterns, relative to *image_dir*, selecting inputs.

    Raises:
        FileNotFoundError: If *weights_path* or *json_path* does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Resize the shorter side to 256, center-crop 224x224, then normalize with
    # the commonly used ImageNet per-channel mean/std.
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for path in (weights_path, json_path):
        if not os.path.exists(path):
            raise FileNotFoundError("file: '{}' does not exist.".format(path))

    # Create the model and load its weights; eval() only needs to happen once.
    model = resnet50(num_classes=10).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # The index -> label map is the same for every image: read it once and
    # let `with` close the file.
    with open(json_path, "r", encoding="utf-8") as json_file:
        class_indict = json.load(json_file)

    # set() drops duplicates from overlapping patterns; sorted() makes the
    # processing order deterministic (glob order is filesystem-dependent).
    file_list = sorted({path
                        for pattern in patterns
                        for path in glob.glob(os.path.join(image_dir, pattern))})
    if not file_list:
        print("no images matching {} found in '{}'".format(list(patterns), image_dir))
        return

    # savefig() does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for filename in file_list:
            print(filename)
            start = time.time()
            # Close the file handle once the pixels are decoded; convert()
            # turns every input (grayscale, RGBA, palette, ...) into RGB.
            with Image.open(filename) as raw:
                img = raw.convert('RGB')
            # [C, H, W] -> [N, C, H, W]: add the batch dimension.
            batch = torch.unsqueeze(data_transform(img), dim=0).to(device)
            # Predict class.
            output = torch.squeeze(model(batch)).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = int(torch.argmax(predict))
            end = time.time()
            print_res = "class: {}   prob: {:.3}  time: {:.4}s".format(
                class_indict[str(predict_cla)], predict[predict_cla].item(), end - start)
            print(print_res)

            # Use a fresh figure per image and close it afterwards. Drawing on
            # the implicit global axes would stack every earlier image
            # underneath, growing memory and skewing the axis extents.
            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.set_title(print_res)
            stem = os.path.splitext(os.path.basename(filename))[0]
            fig.savefig(os.path.join(output_dir, "{}.jpg".format(stem)))
            plt.close(fig)

if __name__ == "__main__":
    # Run batch inference only when executed as a script, not on import.
    main()
